Cross Correlation with CovidEstim and Wastewater Data

# ---- Configuration ----

# State whose county-level results are analysed (used to build file paths).
state <- "ma"

# Identifiers of the prior specifications; one result file exists per version.
priors_versions <- paste0("v", 1:4)

# Lookup table from version id to the LaTeX-style label shown in facets
# and legends.
versions <- tibble(
  version = priors_versions,
  vlabel = c(
    "$Priors\\,Do\\,Not\\,Vary\\,by\\,County\\,and\\,Date$",
    "$\\beta$ Centered at Empirical Value",
    "$P(S_1|untested)$ and $\\beta$ Centered at Empirical Values",
    "$P(S_1|untested)$ Centered at Empirical Value"
  )
)

# Folder holding the bias-corrected estimates for `state`; the trailing "/"
# is kept because file names are appended to it with paste0().
state_corrected_path <- here("analysis/results/adj_biweekly_county/", state, "/")


# ---- Estimated series ----

# Date-to-biweek lookup (left grouped by biweek, as before).
dates <- group_by(readRDS(here("data/date_to_biweek.RDS")), biweek)

# Bias-corrected estimates: read one file per prior version, tag each with
# its version id, then attach calendar dates and the display labels.
corrected <- map_df(
  priors_versions,
  function(v) {
    adj_file <- paste0(state_corrected_path, "adj_", v, ".RDS")
    mutate(readRDS(adj_file), version = v)
  }
) %>%
  left_join(dates, relationship = "many-to-many") %>%
  left_join(versions)

# Covidestim estimates reduced to one row per remaining key combination
# once the daily/weekly date columns are dropped.
covidestim <- here("data/covidestim/covidestim_biweekly_all_counties.RDS") %>%
  readRDS() %>%
  select(-c(date, week)) %>%
  distinct()

# County-level wastewater concentrations.
# NOTE(review): the pivot below assumes the only columns are fips, state,
# name, biweek and mean_conc -- confirm against the RDS file.
waste <- readRDS(here("data/county_level", "biobot_wastewater_county.RDS"))

# Fill gaps in each county's series with a rolling mean.
# Steps: (1) record each county's first biweek, (2) pivot wide then long so
# every county has a row for every biweek seen in the data (NA where it had
# none), (3) compute a 4-biweek rolling mean that ignores NAs, (4) use it to
# replace missing values only.
waste <- waste %>% 
  group_by(fips) %>%
  # keep track of the first biweek per county so that values are not
  # filled in before the first data point when using the rolled mean
  mutate(min_biweek = min(biweek)) %>%
  ungroup() %>%
  pivot_wider(names_from = biweek,
              values_from = mean_conc) %>% 
  pivot_longer(cols = -c(fips,state, name, min_biweek), 
               names_to = "biweek", 
               values_to = "mean_conc") %>%
  # biweek became a column name (character) in the wide form; restore it
  mutate(biweek = as.numeric(biweek)) %>%
  group_by(fips) %>%
  # rows must be in biweek order for the rolling window to be meaningful
  arrange(biweek) %>%
  mutate(rolled_mean = RcppRoll::roll_mean(mean_conc,
                                           n = 4,
                                           na.rm = TRUE,
                                           fill = NA),
         # observed values are kept; only NAs at or after the county's
         # first biweek are replaced by the rolled mean
         mean_conc = ifelse(is.na(mean_conc) & biweek >= min_biweek, 
                            rolled_mean, mean_conc)) %>% ungroup()

# Combine bias-corrected estimates with wastewater and covidestim series,
# and record the first/last calendar date covered by each biweek.
wjoined <- corrected %>%
  # Nantucket is reported under a combined Nantucket/Dukes fips; map any
  # fips containing 25019 to Nantucket, for which wastewater data exist
  mutate(fips = ifelse(grepl("25019", fips), "25019", fips)) %>%
  inner_join(waste) %>%
  left_join(covidestim, relationship = "many-to-many") %>%
  group_by(biweek) %>%
  mutate(
    mindate = min(date),
    maxdate = max(date)
  ) %>%
  ungroup()

# Counties that have at least one non-missing wastewater value.
counties_with_waste <- wjoined %>%
  group_by(fips) %>%
  filter(any(!is.na(mean_conc))) %>%
  pull(fips) %>%
  unique()

# Diagnostic plot: filled-in series (black) against the rolling mean (red)
# for every county that also appears in the corrected estimates.
ggplot(
  filter(waste, fips %in% corrected$fips),
  aes(x = biweek, y = mean_conc)
) +
  geom_point() +
  geom_line() +
  geom_line(aes(y = rolled_mean), color = "darkred") +
  facet_wrap(~fips)

Wastewater

# Rows with a non-missing wastewater concentration per county,
# best-covered counties first.
count(
  wjoined,
  fips,
  wt = !is.na(mean_conc),
  name = "obs_w_notna",
  sort = TRUE
)

Relationship between Observed Cases and Effective Wastewater Concentration

# Observed biweekly case counts for MA counties, with the combined
# Nantucket/Dukes fips mapped to Nantucket (25019) as in `wjoined`.
ma_observed <- readRDS(here("data/county_level/ma/ma_county_biweekly.RDS")) %>%
  mutate(fips = ifelse(grepl("25019", fips), "25019", fips))

# Attach wastewater concentrations (dropping a `date` column from `waste`
# if it has one), keep counties with wastewater data, and express positive
# cases per capita.
ma_observed <- ma_observed %>%
  left_join(waste[, colnames(waste) != "date"]) %>%
  filter(fips %in% counties_with_waste) %>%
  mutate(positive = positive / population) %>%
  ungroup()

# Single global factor that rescales per-capita cases onto the wastewater
# axis: ratio of the two series' overall maxima.
adj <- max(ma_observed$mean_conc, na.rm = TRUE) /
  max(ma_observed$positive, na.rm = TRUE)


# Trends in wastewater and observed cases, one panel per county.
# The secondary axis undoes the `adj` rescaling; the transform is passed
# positionally so it works whether ggplot2 names that argument `trans`
# or `transform`.
ma_observed %>%
  ggplot(aes(x = date, y = positive * adj)) +
  geom_point(alpha = .3, size = .8) +
  geom_line(aes(color = 'Positive Cases Normalized by Population')) +
  geom_line(aes(y = mean_conc, color = 'Wastewater')) +
  geom_point(aes(y = mean_conc), color = "darkred",
             alpha = .3,
             size = .8) +
  facet_wrap(~fips,
             ncol = 1,
             scales = "free") +
  scale_y_continuous(sec.axis = sec_axis(
    ~ . / adj,
    name = 'Positive Cases Normalized by Population Size')) +
  theme_bw() +
  theme(axis.title = element_text(size = 18),
        legend.position = "top") +
  scale_color_manual(name = '',
                     values = c('Wastewater' = 'darkred',
                                'Positive Cases Normalized by Population' = 'black')) +
  guides(color = guide_legend(override.aes = list(linewidth = 3))) +
  labs(y = "Effective Concentration of SARS-CoV-2")

Look at time windows on different scales to better see trends before Omicron.

# Trends in wastewater and observed cases, split into the periods before
# and after 2021-10-01 so each window gets its own (free) y scale.
# `positive` is already cases / population here, so the secondary axis is
# labelled accordingly. The sec_axis transform is passed positionally so
# it works whether ggplot2 names that argument `trans` or `transform`.
ma_observed %>%
  filter(!is.na(mean_conc)) %>%
  mutate(before = ifelse(date <= mdy("10-01-2021"),
                         "Before 10-01-2021", "After 10-01-2021"),
         # fix the panel order: "Before" on the left
         before = factor(before, levels = c(
           "Before 10-01-2021", "After 10-01-2021"))) %>%
  ggplot(aes(x = date, y = positive * adj)) +
  geom_point(alpha = .3, size = .8) +
  geom_line(aes(color = 'Positive Cases Normalized by Population')) +
  geom_line(aes(y = mean_conc, color = 'Wastewater')) +
  geom_point(aes(y = mean_conc), color = "darkred", alpha = .3, size = .8) +
  facet_wrap(fips ~ before,
             ncol = 2,
             scales = "free") +
  scale_y_continuous(sec.axis = sec_axis(
    ~ . / adj,
    name = 'Positive Cases Normalized by Population Size')) +
  theme_bw() +
  theme(axis.title = element_text(size = 18),
        legend.position = "top") +
  scale_color_manual(name = '',
                     values = c('Wastewater' = 'darkred',
                                'Positive Cases Normalized by Population' = 'black')) +
  guides(color = guide_legend(override.aes = list(linewidth = 3))) +
  labs(y = "Effective Concentration of SARS-CoV-2")

# Pooled ribbon plot: bias-corrected interval (exp_cases_lb..exp_cases_ub,
# rescaled by `adj`) per prior version, with the wastewater series overlaid,
# for all counties at once (the per-county filter is commented out).
# NOTE(review): as written this depends on objects not defined at this point:
#   - `custom_title` is only assigned inside functions further down;
#   - `pal` is assigned on the line immediately after the plot;
#   - `adj` here is the global cases-to-wastewater factor computed for the
#     observed-case plots, not one derived from exp_cases_median;
#   - the overlay uses `waste$date`; confirm `waste` has a `date` column.
# Presumably a non-evaluated draft of plot_ribbon() -- confirm before running.
wjoined %>%
  #  filter(fips == county_fips  & date >= begin_date & date <= end_date) %>%
    ggplot() +
    geom_ribbon(aes(x = date, 
               ymin = exp_cases_lb*adj,
               ymax = exp_cases_ub*adj,
               fill = vlabel),
               alpha = .7) +
    geom_line(data = waste,
              aes(x = date, y =mean_conc ),
              color = "#DB4048",
              size = 1.1) +
    geom_point(data = waste,
              aes(x = date, y =mean_conc ),
              color = "#DB4048",
              alpha = .5,
              size = 1.2) +
    facet_wrap(~vlabel, labeller= as_labeller(TeX, default = label_parsed)) +
    scale_fill_manual(values = pal) +
    theme_bw() +
    theme(
      legend.position = "none",
      plot.title = element_text(face = "bold", size = 16, hjust = .5),
      axis.title = element_text(size = 18),
      strip.text = element_text(size = 14)
    )+
    scale_y_continuous(sec.axis = sec_axis(~./adj,
                                       name = "Corrected Infection Estimates",
                                       labels = comma),
                       labels = comma) +
    labs(y = "Effective Concentration Rolling Average",
         title = custom_title) +
    scale_x_date(date_labels = "%b %Y")
# Four fill/line colours, one per prior version; read as a global by
# plot_ribbon() and plot_line().
pal <- c("#10BAC5", "#1B10C5", "#EFB719", "#900C3F")

# pal <- c("#74A09F", "#A0748B", 
#          "#748BA0", "#A08974",
#          "#D49E9F", "#D4B89E", "#AFCFE5")
# 
# 


# Plot one county's wastewater concentration against its bias-corrected
# infection estimates, one facet per prior version.
#
# Arguments:
#   county_fips  fips code of the county to plot.
#   end_date     last date to include; anything lubridate::ymd() can parse.
#   w_data       joined data with (at least) fips, name, date, biweek,
#                mean_conc and exp_cases_median, plus whatever the chosen
#                plotting helper needs.
#   option       "ribbon" (interval bands, via plot_ribbon()) or
#                "line" (median lines, via plot_line()).
#
# Returns the ggplot object produced by the chosen helper. An unknown
# `option` is now an error instead of silently returning NULL.
compare_county_wastewater <- function(county_fips,
                                      end_date = "2022-02-12",
                                      w_data,
                                      option = "ribbon") {
  option <- match.arg(option, c("ribbon", "line"))
  end_date <- ymd(end_date)

  county_name <- w_data %>%
    filter(fips == county_fips) %>%
    pull(name) %>%
    na.omit() %>%
    unique()

  custom_title <- paste0(
    "Comparing Wastewater Concentration to Corrected Estimates:\n",
    county_name,
    ", FIPS: ",
    county_fips
  )

  # All rows for this county up to end_date; reused for every quantity
  # below instead of repeating the same filter.
  in_window <- w_data %>%
    filter(fips == county_fips & date <= end_date)

  # The plotted window starts at the first date with a wastewater value.
  begin_date <- in_window %>%
    filter(!is.na(mean_conc)) %>%
    pull(date) %>%
    min()

  # Scale factor mapping infection estimates onto the wastewater axis:
  # ratio of the two series' maxima within the window. (The original
  # computed this in two identical if/else branches.)
  cases_max <- max(pull(in_window, exp_cases_median), na.rm = TRUE)
  conc_max <- max(pull(in_window, mean_conc), na.rm = TRUE)
  adj <- conc_max / cases_max

  # One wastewater point per biweek, dated at the biweek's first date.
  waste_dat <- in_window %>%
    filter(date >= begin_date) %>%
    select(fips, date, biweek, mean_conc) %>%
    group_by(biweek, fips) %>%
    summarize(date = min(date),
              mean_conc = unique(mean_conc),
              .groups = "drop") %>%
    distinct()

  # Helpers are looked up when called, so the plot_ribbon()/plot_line()
  # definitions current at call time are used.
  plot_fun <- switch(option,
                     ribbon = plot_ribbon,
                     line = plot_line)

  plot_fun(w_data = w_data,
           waste_df = waste_dat,
           adj = adj,
           county_fips = county_fips,
           begin_date = begin_date,
           end_date = end_date,
           custom_title = custom_title)
}



# Ribbon plot of the bias-corrected interval for one county, one facet per
# prior version, with the wastewater series overlaid in red.
#
# Arguments:
#   w_data        joined data with fips, date, exp_cases_lb, exp_cases_ub
#                 and vlabel.
#   waste_df      wastewater series to overlay (date, mean_conc).
#   adj           factor that rescales estimates onto the wastewater axis;
#                 the secondary axis divides by it again.
#   county_fips   county to plot.
#   begin_date, end_date   inclusive date window.
#   custom_title  plot title.
#
# Uses the globals `pal` (fill colours), `TeX` and `comma`.
# Line layers use `linewidth`; `size` for lines is deprecated since
# ggplot2 3.4.0 and the rest of the file already uses `linewidth`.
plot_ribbon <- function(w_data, waste_df,
                        adj, county_fips,
                        begin_date, end_date,
                        custom_title) {

  county_window <- w_data %>%
    filter(fips == county_fips & date >= begin_date & date <= end_date)

  ggplot(county_window) +
    geom_ribbon(aes(x = date,
                    ymin = exp_cases_lb * adj,
                    ymax = exp_cases_ub * adj,
                    fill = vlabel),
                alpha = .4) +
    geom_line(data = waste_df,
              aes(x = date, y = mean_conc),
              color = "#DB4048",
              linewidth = 1.1) +
    geom_point(data = waste_df,
               aes(x = date, y = mean_conc),
               color = "#DB4048",
               alpha = .5,
               size = 1.2) +
    facet_wrap(~vlabel, labeller = as_labeller(TeX, default = label_parsed)) +
    scale_fill_manual(values = pal) +
    theme_bw() +
    theme(
      legend.position = "none",
      plot.title = element_text(face = "bold", size = 16, hjust = .5),
      axis.title = element_text(size = 18),
      strip.text = element_text(size = 14)
    ) +
    scale_y_continuous(sec.axis = sec_axis(~ . / adj,
                                           name = "Corrected Infection Estimates",
                                           labels = comma),
                       labels = comma) +
    labs(y = "Effective Concentration Rolling Average",
         title = custom_title) +
    scale_x_date(date_labels = "%b %Y")
}



##################
# LINE PLOTS
##################
# Line plot of the median bias-corrected estimate for one county, one facet
# per prior version, with the wastewater series overlaid in red.
#
# Arguments:
#   w_data        joined data with fips, biweek, vlabel, date,
#                 exp_cases_median and mean_conc.
#   waste_df      wastewater series to overlay (date, mean_conc).
#   adj           factor that rescales estimates onto the wastewater axis;
#                 the secondary axis divides by it again.
#   county_fips   county to plot.
#   begin_date, end_date   inclusive date window, applied to each biweek's
#                 first date.
#   custom_title  plot title.
#
# Uses the globals `pal` (line colours), `TeX` and `comma`.
plot_line <- function(w_data, waste_df,
                      adj, county_fips,
                      begin_date, end_date,
                      custom_title) {

  # Restrict to the county first so only its rows are summarised; the date
  # filter must come after the summary because it acts on the biweek's
  # first date.
  county_biweekly <- w_data %>%
    filter(fips == county_fips) %>%
    group_by(fips, biweek, vlabel) %>%
    summarize(date = min(date),
              exp_cases_median = unique(exp_cases_median),
              mean_conc = unique(mean_conc),
              .groups = "drop") %>%
    filter(date >= begin_date & date <= end_date)

  ggplot(county_biweekly) +
    geom_point(aes(x = date,
                   y = exp_cases_median * adj,
                   color = vlabel),
               alpha = .7) +
    geom_line(aes(x = date,
                  y = exp_cases_median * adj,
                  color = vlabel),
              alpha = 1,
              linewidth = 1.1) +
    # `linewidth` rather than the deprecated `size` for line layers,
    # matching the layer above
    geom_line(data = waste_df,
              aes(x = date, y = mean_conc),
              color = "#DB4048",
              linewidth = 0.9) +
    geom_point(data = waste_df,
               aes(x = date, y = mean_conc),
               color = "#DB4048",
               alpha = .5,
               size = 1.1) +
    facet_wrap(~vlabel, labeller = as_labeller(TeX, default = label_parsed)) +
    scale_color_manual(values = pal) +
    theme_bw() +
    theme(
      legend.position = "none",
      plot.title = element_text(face = "bold", size = 16, hjust = .5),
      axis.title = element_text(size = 18),
      strip.text = element_text(size = 14)
    ) +
    scale_y_continuous(sec.axis = sec_axis(~ . / adj,
                                           name = "Corrected Infection Estimates",
                                           labels = comma),
                       labels = comma) +
    labs(y = "Effective Concentration Rolling Average",
         title = custom_title) +
    scale_x_date(date_labels = "%b %Y")
}

Relationship Between Bias Corrected Counts and Wastewater Effective Concentrations

Initial Visualizations

# When `subset` is TRUE, restrict the run to two counties.
if (subset) {
  counties_with_waste <- c("25001", "25005")
}

# For each county, print the ribbon and line comparisons side by side.
walk(
  counties_with_waste,
  function(county) {
    ribbon_plt <- compare_county_wastewater(
      county_fips = county, w_data = wjoined, option = "ribbon"
    )
    line_plt <- compare_county_wastewater(
      county_fips = county, w_data = wjoined, option = "line"
    )
    print(cowplot::plot_grid(ribbon_plt, line_plt, nrow = 1))
  }
)

 # compare_county_wastewater(county_fips = "25005", w_data =wjoined, option="line")
# Same comparison as compare_county_wastewater(), kept under its own name
# for the covidestim figures.
#
# The previous body was a line-for-line copy of compare_county_wastewater();
# it now delegates to it so the two cannot drift apart. The covidestim
# overlay itself comes from whichever plot_ribbon() is defined at call time,
# since the helper is looked up when the plot is built.
#
# Arguments and return value are those of compare_county_wastewater().
compare_county_wastewater_covidestim <- function(county_fips,
                                                 end_date = "2022-02-12",
                                                 w_data,
                                                 option = "ribbon") {
  compare_county_wastewater(
    county_fips = county_fips,
    end_date = end_date,
    w_data = w_data,
    option = option
  )
}



# Ribbon plot for one county that also shows covidestim infections.
#
# NOTE: this re-defines plot_ribbon(); once this definition has been
# evaluated, every caller that looks up plot_ribbon() gets this version.
#
# Arguments:
#   w_data        joined data with fips, date, exp_cases_lb, exp_cases_ub,
#                 vlabel, infections, mindate and maxdate.
#   waste_df      wastewater series to overlay (date, mean_conc).
#   adj           factor that rescales estimates onto the wastewater axis;
#                 the secondary axis divides by it again.
#   county_fips   county to plot.
#   begin_date, end_date   inclusive date window.
#   custom_title  plot title.
#
# Uses the globals `pal` (fill colours), `TeX` and `comma`.
# The wastewater line uses `linewidth`; `size` for lines is deprecated
# since ggplot2 3.4.0.
plot_ribbon <- function(w_data, waste_df,
                        adj, county_fips,
                        begin_date, end_date,
                        custom_title) {

  county_window <- w_data %>%
    filter(fips == county_fips & date >= begin_date & date <= end_date)

  ggplot(county_window) +
    geom_ribbon(aes(x = date,
                    ymin = exp_cases_lb * adj,
                    ymax = exp_cases_ub * adj,
                    fill = vlabel),
                alpha = .5) +
    geom_line(data = waste_df,
              aes(x = date, y = mean_conc, color = 'waste'),
              linewidth = 1.1) +
    # covidestim infections: one horizontal segment per biweek, spanning
    # the biweek's first to last date
    geom_linerange(aes(xmin = mindate,
                       xmax = maxdate,
                       y = infections * adj,
                       color = 'covidestim')) +
    geom_point(data = waste_df,
               aes(x = date, y = mean_conc),
               color = "#DB4048",
               alpha = .5,
               size = 1.2) +
    facet_wrap(~vlabel, labeller = as_labeller(TeX, default = label_parsed)) +
    scale_fill_manual(values = pal, labels = c('', '', '', ''),
                      name = "Probabilistic Bias Intervals") +
    theme_bw() +
    theme(
      plot.title = element_text(face = "bold", size = 22, hjust = .5),
      axis.title = element_text(size = 22),
      strip.text = element_text(size = 14, color = "white"),
      legend.text = element_text(size = 20),
      legend.title = element_text(size = 25),
      strip.background = element_rect(fill = "#3E3D3D")
    ) +
    scale_y_continuous(sec.axis = sec_axis(~ . / adj,
                                           name = "Corrected Infection Estimates",
                                           labels = comma),
                       labels = comma) +
    labs(y = "Effective Concentration Rolling Average",
         title = custom_title) +
    scale_x_date(date_labels = "%b %Y") +
    scale_color_manual(name = '', values = c('covidestim' = 'darkblue',
                                             'waste' = '#DB4048')) +
    guides(color = guide_legend(override.aes = list(size = 12, linewidth = 7)),
           fill = guide_legend(override.aes = list(size = 6)))
}
# walk(counties_with_waste,
#      ~ {
#        plt <- compare_county_wastewater_covidestim(
#          county_fips =.x, w_data =wjoined, option = "ribbon")
#        print(plt)
#      })


# One covidestim ribbon plot per county, with legends suppressed so a
# single shared legend can be placed above the grid.
plotlist <- map(counties_with_waste,
     ~ {
      compare_county_wastewater_covidestim(
         county_fips =.x, w_data =wjoined, option = "ribbon") +
         theme(legend.position="none",
               axis.text.x = element_text(size = 12),
               axis.title =element_text(size= 14))
     })

# Shared legend, extracted from the first county's plot after restyling it
# for display on top.
legend_b <- cowplot::get_legend(
  plotlist[[1]] +
    guides(color = guide_legend(
      nrow = 1,
      override.aes = list(
      linewidth=4))) +
    theme(legend.position = "top",
          legend.text=element_text(size =30),
          legend.title = element_text(size = 30))
)

library(cowplot)
title_gg <- ggdraw() +
  draw_label(
    "Comparing Bias Corrected Infections with Wastewater Concentrations Over Time",
             fontface="bold",
             x = 0,
             hjust = 0,
             size = 35)+
  theme(
    # add margin on the left of the drawing canvas,
    # so title is aligned with left edge of first plot
    plot.margin = margin(0, 0, 0, 7)
  )

plts <- cowplot::plot_grid(plotlist=plotlist, ncol =2)

# Title, legend and county grid stacked vertically. The result is kept in a
# variable so the exact object can be handed to ggsave() below instead of
# relying on ggsave()'s implicit "last plot" default.
wastewater_by_county <- cowplot::plot_grid(title_gg,
                                           legend_b,
                                           plts,
                                           ncol = 1 ,
                                           rel_heights = c(.05, .1, .85))
wastewater_by_county

#   
# ggpubr::ggarrange(plotlist=plotlist,
#                   ncol = 2,
#                   common.legend=TRUE)
# 
# plotlist$common.legend = TRUE
# plotlist$ncol = 2
# do.call(ggpubr::ggarrange, plotlist)


ggsave(here("thesis/figure/wastewater_ma_by_county.pdf"),
       plot = wastewater_by_county)

Cross Correlation

# Cross-correlation plots for one county, one panel per prior version,
# arranged in a single row.
#
# Arguments:
#   input_df_for_fips  data for one county containing a `version` column
#                      and the columns named in `varnames`.
#   varnames           the two series passed to plot_ccf().
#
# Returns the cowplot grid. Works for any number of versions present in
# the data (the original indexed exactly four and failed otherwise).
plot_all_ccf <- function(input_df_for_fips,
                         varnames = c("differenced_mean_conc",
                                      "differenced_exp_median")) {

  by_version <- input_df_for_fips %>%
    group_split(version)

  # progress output: number of versions found for this county
  message(length(by_version))

  ccf_plots <- lapply(by_version, plot_ccf, varnames = varnames)

  cowplot::plot_grid(plotlist = ccf_plots, nrow = 1)
}


# Plot the sample cross-correlation function between two columns and label
# the lag at which the correlation is largest.
#
# Arguments:
#   input_df_for_version  data for one county and one version; must contain
#                         `vlabel`, `name` and the columns in `varnames`.
#   varnames              character vector of length two: first and second
#                         series passed to stats::ccf(), in that order.
#
# Returns a ggplot object. Uses the globals `TeX` and `theme_c`.
plot_ccf <- function(input_df_for_version,
                     varnames = c("differenced_mean_conc",
                                  "differenced_exp_median")) {

  stopifnot(all(varnames %in% names(input_df_for_version)))

  lab <- unique(input_df_for_version$vlabel)
  county_name <- gsub("County, MA", "", unique(input_df_for_version$name))

  # Columns are extracted with [[ so the strings in `varnames` are always
  # treated as column names.
  cross_correlations <- ccf(
    input_df_for_version[[varnames[1]]],
    input_df_for_version[[varnames[2]]],
    plot = FALSE
  )

  ccf_res <- tibble(lag = cross_correlations$lag[, , 1],
                    correlation = cross_correlations$acf[, , 1])

  ccf_res %>%
    mutate(lag_at_max = ifelse(correlation == max(correlation),
                               lag, NA),
           maxcorr = round(max(correlation), 3)) %>%
    ggplot(aes(x = lag, y = correlation)) +
    geom_point() +
    geom_linerange(aes(ymin = 0, ymax = correlation)) +
    theme_c() +
    theme(plot.title = element_text(size = 12, hjust = .5,
                                    margin = margin(1, 1, 8, 1)),
          plot.margin = margin(10, 1, 1, 1),
          plot.subtitle = element_text(size = 12, hjust = .5,
                                       margin = margin(4, 1, 1, 1))) +
    labs(title = TeX(paste0(
      "Cross Correlation between $(Y_t)$ and $(Z_t)$, County: ",
      county_name)),
      subtitle = TeX(paste0("Version: ", lab))) +
    # Only the row(s) at the maximum carry a label; restricting the layer's
    # data to them avoids a "removed rows containing missing values"
    # warning for every other lag.
    geom_label(data = function(d) d[!is.na(d$lag_at_max), ],
               aes(x = lag_at_max, y = correlation,
                   label = paste0("Max Correlation: ",
                                  maxcorr, "\n Lag: ",
                                  lag_at_max)),
               hjust = -.1) +
    ylim(-1, 1) +
    scale_x_continuous(n.breaks = 10)
}




# Find the lag with the largest sample cross-correlation between two
# columns of one county/version slice.
#
# Arguments:
#   input_df_for_version  data for one county and one version; must contain
#                         `vlabel`, `fips`, `name` and the columns in
#                         `varnames`.
#   varnames              character vector of length two: first and second
#                         series passed to stats::ccf(), in that order.
#   comparison            label stored in the output to say which series
#                         was compared with wastewater.
#
# Returns a tibble with vlabel, fips, max_lag, max_corr, comparison and
# county_name. If several lags tie for the maximum, the first (smallest)
# lag is reported, so a slice contributes one row per vlabel/fips/name
# combination rather than one row per tied lag.
get_max_lag_by_version <- function(input_df_for_version,
                                   varnames = c("differenced_mean_conc",
                                                "differenced_exp_median"),
                                   comparison = "bias corrected") {

  stopifnot(all(varnames %in% names(input_df_for_version)))

  # Columns are extracted with [[ so the strings in `varnames` are always
  # treated as column names.
  cross_correlations <- ccf(
    input_df_for_version[[varnames[1]]],
    input_df_for_version[[varnames[2]]],
    plot = FALSE
  )

  ccf_res <- tibble(lag = cross_correlations$lag[, , 1],
                    correlation = cross_correlations$acf[, , 1])

  best <- ccf_res[which.max(ccf_res$correlation), ]

  tibble(vlabel = unique(input_df_for_version$vlabel),
         fips = unique(input_df_for_version$fips),
         max_lag = best$lag,
         max_corr = best$correlation,
         comparison = comparison,
         county_name = unique(input_df_for_version$name))
}
# First differences of every series, per prior version and county, on the
# biweekly scale. Dates are dropped first so each biweek contributes a
# single row; rows whose corrected-estimate or wastewater difference is
# missing are removed.
differenced_all <- wjoined %>%
  filter(!is.na(mean_conc)) %>%
  select(-date) %>%
  distinct() %>%
  arrange(biweek) %>%
  mutate(
    differenced_exp_median = exp_cases_median - lag(exp_cases_median),
    differenced_mean_conc = mean_conc - lag(mean_conc),
    differenced_positive = positive - lag(positive),
    differenced_covidestim = infections - lag(infections),
    .by = c(vlabel, fips)
  ) %>%
  filter(!is.na(differenced_exp_median), !is.na(differenced_mean_conc))

Cross Correlation (Wastewater and Covidestim Estimates)

# glimpse(differenced_all)

# Counties to plot: two when `subset` is TRUE, otherwise all of them.
fips_sub <- if (subset) {
  c("25001", "25005")
} else {
  unique(differenced_all$fips)
}

differenced_all_sub <- filter(differenced_all, fips %in% fips_sub)

# Wastewater vs covidestim: one CCF plot per county. covidestim does not
# vary by prior version, so only the v1 rows are used.
differenced_all_sub %>%
  filter(version == "v1", !is.na(differenced_covidestim)) %>%
  group_split(fips) %>%
  walk(function(county_df) {
    print(plot_ccf(county_df,
                   varnames = c("differenced_mean_conc",
                                "differenced_covidestim")))
  })

Cross Correlation (Wastewater and Observed Cases)

# Wastewater vs observed cases: one CCF plot per county, using the v1 rows.
differenced_all_sub %>%
  filter(version == "v1") %>%
  group_split(fips) %>%
  walk(function(county_df) {
    print(plot_ccf(county_df,
                   varnames = c("differenced_mean_conc",
                                "differenced_positive")))
  })

Cross Correlation (Wastewater and Bias-Corrected Cases)

# Wastewater vs bias-corrected estimates: per county, one row of CCF
# panels covering every prior version.
differenced_all_sub %>%
  group_split(fips) %>%
  walk(function(county_df) {
    print(plot_all_ccf(county_df))
  })

Compare Lags with Maximum Correlation Between Covidestim, Wastewater, and Bias Corrected Counts

Compare Lags Among Versions of Bias Corrected Counts

For each county, lag with the maximum correlation:

# Lag of maximum cross-correlation with wastewater for the bias-corrected
# counts: one result per county and prior version.
lags_all_pb <- differenced_all %>%
  group_split(fips, vlabel) %>%
  map_df(get_max_lag_by_version)


# Per county, the distinct lags that were best under at least one version.
lags_all_pb %>%
  group_by(fips) %>%
  summarize(
    `Lags with Maximum Correlation in Some Version` = toString(unique(max_lag))
  )

# Same quantity for observed cases, computed from the v1 rows; `vlabel`
# is dropped from the result.
lags_all_observed <- differenced_all %>%
  filter(version == "v1") %>%
  group_split(fips) %>%
  map_df(function(county_df) {
    get_max_lag_by_version(
      input_df_for_version = county_df,
      varnames = c("differenced_mean_conc", "differenced_positive"),
      comparison = "observed"
    )
  }) %>%
  select(-vlabel)


# Fill colours keyed by series name: observed cases plus one entry per
# bias-corrected version (in the order the versions appear in lags_all_pb).
cols <- c("red", pal)
names(cols) <- c("observed", unique(lags_all_pb$vlabel))


# Compare maximum correlations of the bias-corrected versions with those of
# observed cases, counties ordered by their median maximum correlation.
#
# Legend labels are generated from the scale's own breaks, so each label is
# always paired with its colour regardless of how the breaks are ordered.
# (Previously a positional vector built from the non-unique
# lags_all_pb$vlabel was supplied, whose length and order did not have to
# match the breaks.)
lags_all_pb %>%
  bind_rows(lags_all_observed) %>%
  # observed rows have no vlabel; fall back to their comparison name
  mutate(source = ifelse(!is.na(vlabel), vlabel, comparison)) %>%
  group_by(fips) %>%
  mutate(m = median(max_corr)) %>%
  ungroup() %>%
  mutate(max_lag = factor(max_lag)) %>%
  ggplot(aes(x = fct_reorder(fips, m,
                             .desc = TRUE),
             y = max_corr,
             color = source,
             shape = max_lag)) +
  geom_jitter(aes(fill = source), color = "black",
              height = 0, width = .1, size = 2) +
  # fillable shapes, one per distinct lag value
  scale_shape_manual(values = c(21, 22, 24)) +
  scale_fill_manual(values = cols,
                    name = 'Time Series Compared with\nWastewater:',
                    labels = function(breaks) {
                      TeX(ifelse(breaks == "observed", "$observed$", breaks))
                    }) +
  theme_bw() +
  theme_c() +
  theme(
    axis.text.x = element_text(angle = 20, vjust = .4),
    axis.title.x = element_text(margin = margin(3, 3, 3, 3)),
    legend.text = element_text(size = 14),
    axis.text.y = element_text(size = 11),
    legend.title = element_text(face = "bold", size = 16)) +
  labs(x = "County FIPS",
       y = "Maximum Correlation",
       title = "Comparing Maximum Correlation of Bias Corrected Counts with Wastewater\nto that Between Observed Cases and Wastewater",
       shape = "Maximum Lag:",
       color = "Time Series Compared with Wastewater:") +
  guides(shape = guide_legend(override.aes = list(size = 4)),
         fill = guide_legend(override.aes = list(size = 2, shape = 24)))

ggsave(here("thesis/figure/correlation_observed_pb.jpeg"))



# max series by rank; use for following comparison
# Within each county, `series` is the version label whose max_corr has the
# highest rank; the label(s) that occur in the most rows overall are kept.
max_series_by_rank <- lags_all_pb %>%
  group_by(fips) %>%
  mutate(r = rank(max_corr),
            series = vlabel[which.max(r)]) %>%
  group_by(series) %>%
  summarize(n=n()) %>%
  filter(n==max(n)) %>%
  pull(series)

# For each best-ranked version: how many counties it wins, and which ones.
lags_all_pb %>%
  group_by(fips,county_name) %>%
  mutate(r = rank(max_corr),
            series = vlabel[which.max(r)]) %>%
  select(series,fips,county_name) %>%
  distinct() %>%
  group_by(series) %>%
  summarize(n=n(),
            county_name = paste(county_name,collapse=", "),
            fips =  paste(fips,collapse=", ")) 
# Counties tabulated by the lag of their best-ranked bias-corrected version.
lags_all_pb %>%
  group_by(fips,county_name) %>%
  mutate(r = rank(max_corr),
            series = vlabel[which.max(r)],
         lag = max_lag[which.max(r)]) %>%
  select(lag,fips,county_name) %>%
  distinct() %>%
  group_by(lag) %>%
  summarize(n=n(),
            county_name = paste(county_name,collapse=", "),
            fips =  paste(fips,collapse=", ")) 
# Pooling bias-corrected and observed results: which comparison
# ("bias corrected" or "observed") attains the highest max_corr per county.
 lags_all_pb %>%
  bind_rows(lags_all_observed)%>%
  group_by(fips,county_name) %>%
  mutate(r = rank(max_corr),
            comp = comparison[which.max(r)]) %>%
  select(fips,county_name,comp ) %>%
  distinct() %>%
  group_by(comp) %>%
  summarize(n=n(),
            county_name = paste(county_name,collapse=", "),
            fips =  paste(fips,collapse=", ")) 
# Counties tabulated by the lag of maximum correlation for observed cases.
 lags_all_observed %>% 
   group_by(fips,county_name) %>%
  mutate(r = rank(max_corr),
         lag = max_lag[which.max(r)]) %>%
  distinct() %>%
  group_by(lag) %>%
  summarize(n=n(),
            county_name = paste(county_name,collapse=", "),
            fips =  paste(fips,collapse=", "))
# NOTE(review): this computes the same lag/fips/county_name table as the
# "best-ranked bias-corrected version" summary above -- confirm whether
# both are needed.
 lags_all_pb %>% 
   group_by(fips,county_name) %>%
  mutate(r = rank(max_corr),
         lag = max_lag[which.max(r)]) %>%
    select(fips,county_name,lag ) %>%
  distinct() %>%
  group_by(lag) %>%
  summarize(n=n(),
            county_name = paste(county_name,collapse=", "),
            fips =  paste(fips,collapse=", "))
# lags_all_pb %>%
#   mutate(max_lag=factor(max_lag)) %>%
#   ggplot(aes(x=fips, y = max_corr,
#              color = vlabel, 
#              shape = max_lag)) +
#   geom_point(size = 2) +
#   viridis::scale_color_viridis(option ="magma", 
#                                discrete=TRUE,
#                                end=.95, 
#                                labels = TeX(lags_all_pb$vlabel)) +
#   theme_bw() +
#   labs(y = "Maximum Correlation") 



# County order for the next figure: descending median of max_corr over the
# bias-corrected and observed results, matching the previous figure.
fips_order <- lags_all_pb %>%
  bind_rows(lags_all_observed) %>%
  group_by(fips) %>%
  mutate(m = median(max_corr)) %>%
  ungroup() %>%
  select(m, fips) %>%
  arrange(desc(m)) %>%
  distinct() %>%
  pull(fips)

# Lag of maximum cross-correlation between wastewater and covidestim
# infections, per county, from the v1 rows.
# NOTE(review): fips 25019 is excluded here -- presumably it has no
# covidestim series; confirm.
lags_all_covidestim <- differenced_all %>%
  filter(fips != "25019") %>%
  filter(version == "v1") %>%
  group_by(fips) %>%
  group_split() %>%
  map_df(~get_max_lag_by_version(
    input_df_for_version = .x,
    varnames = c("differenced_mean_conc",
                 "differenced_covidestim"),
    comparison = "Covidestim"))


# Results for the three series compared with wastewater: the most
# frequently best bias-corrected version, observed cases, and covidestim.
# `%in%` is used because max_series_by_rank holds more than one label when
# versions tie; `==` would then recycle it against the rows.
lags_all <- lags_all_pb %>%
  filter(vlabel %in% max_series_by_rank) %>%
  bind_rows(lags_all_observed) %>%
  bind_rows(lags_all_covidestim)


# Fill colours keyed by the `comparison` value of each series.
# NOTE(review): this overwrites the four-colour global `pal` that
# plot_ribbon() and plot_line() read; re-running those plots after this
# point will pick up these three named colours instead.
pal <- c("red", "#EFB719", "#0C6B90")
names(pal) <- c("observed", "bias corrected", "Covidestim")


# Maximum correlation with wastewater for observed, covidestim and the
# selected bias-corrected series; counties in the same order as the
# previous figure.
#
# Legend labels are generated from the scale's own breaks so each label is
# paired with its colour: "bias corrected" shows the selected version's
# label, the other two show their own names. (A positional label vector
# does not follow the order in which ggplot sorts the breaks.)
lags_all %>%
  mutate(fips = factor(fips, levels = fips_order),
         max_lag = factor(max_lag)) %>%
  ggplot(aes(x = fips,
             y = max_corr,
             fill = comparison,
             shape = max_lag)) +
  geom_jitter(height = 0, width = .1, size = 2, color = "black") +
  scale_fill_manual('Time Series Compared with\nWastewater:',
                    values = pal,
                    labels = function(breaks) {
                      TeX(ifelse(breaks == "bias corrected",
                                 paste(max_series_by_rank, collapse = ", "),
                                 paste0("$", breaks, "$")))
                    }) +
  # fillable shapes, one per distinct lag value
  scale_shape_manual(values = c(21, 22, 24)) +
  theme_bw() +
  theme_c() +
  theme(plot.title = element_text(face = "bold", size = 14, hjust = .5),
        axis.text.x = element_text(angle = 20, vjust = .4),
        axis.title.x = element_text(margin = margin(3, 3, 3, 3)),
        legend.text = element_text(size = 14),
        axis.text.y = element_text(size = 11),
        legend.title = element_text(face = "bold", size = 16)) +
  labs(x = "County FIPS",
       y = "Maximum Correlation",
       title = "Comparing Maximum Correlation of Bias Corrected, Covidestim, and Observed Counts\nwith Wastewater Concentrations",
       shape = "Maximum Lag:",
       color = "Time Series Compared with Wastewater:") +
  guides(shape = guide_legend(override.aes = list(size = 4)),
         fill = guide_legend(override.aes = list(size = 2, shape = 24)))

ggsave(here("thesis/figure/correlation_observed_pb_covidestim.jpeg"))

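A maximum at lag \(-1\) \(\implies\) wastewater leads cases by one biweek: `ccf(x, y)` at lag \(k\) correlates \(x_{t+k}\) with \(y_t\), and the wastewater series is passed as \(x\).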

# County names for the fips codes that have lag results.
county_names <- wjoined %>%
  distinct(name, fips) %>%
  filter(!is.na(name), fips %in% lags_all_pb$fips)


# Maximum correlation per county, one panel per prior version and lag.
lags_all_pb %>%
  left_join(county_names) %>%
  mutate(
    name = gsub(" County, MA", "", name),
    max_lag = paste0("Lag: ", max_lag)
  ) %>%
  ggplot(aes(x = name, y = max_corr, fill = max_lag)) +
  geom_col(position = "dodge", show.legend = FALSE) +
  facet_wrap(
    ~ vlabel + max_lag,
    labeller = as_labeller(latex2exp::TeX, default = label_parsed),
    ncol = 2,
    scales = "free_x"
  ) +
  theme_c() +
  theme(axis.text.x = element_text(angle = 30, vjust = .6)) +
  viridis::scale_fill_viridis(
    discrete = TRUE,
    option = "mako",
    begin = .4,
    end = .8
  )

# Note on thesis - add a couple examples on why beta kernel and gaussian kernel perform relatively similarly in this case (i.e. induced distributions).